# Everything remains the same from the original Mujoco repo, besides flipping
# coordinates of the texture u and v.

# Copyright 2023 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""CLI for converting legacy MSH files to Wavefront OBJ files.

Usage:
  python -m mujoco.msh2obj -i <msh_file> -o <obj_file>
"""

import argparse
import dataclasses
import io
import pathlib

import numpy as np


@dataclasses.dataclass(frozen=True)
class Msh:
    """MuJoCo legacy binary msh file."""

    vertex_positions: np.ndarray
    vertex_normals: np.ndarray
    vertex_texcoords: np.ndarray
    face_vertex_indices: np.ndarray

    @staticmethod
    def create(file: pathlib.Path) -> "Msh":
        """Create a Msh object from a .msh file."""
        if not file.exists():
            raise FileNotFoundError(f"{file} does not exist.")

        with open(file, "rb") as f:
            nvertex = np.fromfile(f, dtype=np.int32, count=1)[0]
            nnormal = np.fromfile(f, dtype=np.int32, count=1)[0]
            ntexcoord = np.fromfile(f, dtype=np.int32, count=1)[0]
            nface = np.fromfile(f, dtype=np.int32, count=1)[0]
            vertex_positions = np.fromfile(f, dtype=np.float32, count=3 * nvertex)
            vertex_normals = np.fromfile(f, dtype=np.float32, count=3 * nnormal)
            vertex_texcoords = np.fromfile(f, dtype=np.float32, count=2 * ntexcoord)
            face_vertex_indices = np.fromfile(f, dtype=np.int32, count=3 * nface)

        if vertex_positions.size != 3 * nvertex:
            raise ValueError(f"Invalid number of vertices: {vertex_positions.size} != 3*{nvertex}.")
        if vertex_normals.size != 3 * nnormal:
            raise ValueError(f"Invalid number of normals: {vertex_normals.size} != 3*{nnormal}.")
        if vertex_texcoords.size != 2 * ntexcoord:
            raise ValueError(f"Invalid number of texcoords: {vertex_texcoords.size} != 2*{{ntexcoord}}.")
        if face_vertex_indices.size != 3 * nface:
            raise ValueError(f"Invalid number of faces: {face_vertex_indices.size} != 3*{nface}.")

        vertex_positions = vertex_positions.reshape(-1, 3)
        vertex_normals = vertex_normals.reshape(-1, 3)
        face_vertex_indices = face_vertex_indices.reshape(-1, 3)

        # Reshape texture coordinates and swap x,y while maintaining the flip
        vertex_texcoords = vertex_texcoords.reshape(-1, 2)
        vertex_texcoords = vertex_texcoords[:, ::-1]  # Swap U and V
        vertex_texcoords[:, 1] = 1 - vertex_texcoords[:, 1]  # Flip V coordinate

        return Msh(
            vertex_positions=vertex_positions,
            vertex_normals=vertex_normals,
            vertex_texcoords=vertex_texcoords,
            face_vertex_indices=face_vertex_indices,
        )


def msh_to_obj(msh_file: pathlib.Path) -> str:
    """Convert a legacy .msh file to the .obj format."""
    msh = Msh.create(msh_file)

    out = io.StringIO()
    for x, y, z in msh.vertex_positions:
        out.write(f"v {x} {y} {z}\n")
    for x, y, z in msh.vertex_normals:
        out.write(f"vn {x} {y} {z}\n")
    for u, v in msh.vertex_texcoords:
        # Write texture coordinates directly since we already swapped them
        out.write(f"vt {u} {v}\n")
    for i, j, k in msh.face_vertex_indices:
        out.write(f"f {i + 1}/{i + 1}/{i + 1} {j + 1}/{j + 1}/{j + 1} {k + 1}/{k + 1}/{k + 1}\n")

    return out.getvalue()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("-i", "--input", type=str, help="Path to the msh file.")
    parser.add_argument("-o", "--output", type=str, help="Path to the obj file.")
    args = parser.parse_args()
    with open(pathlib.Path(args.output), "w") as f:
        f.write(msh_to_obj(pathlib.Path(args.input)))
